"""
Based on https://github.com/ikostrikov/pytorch-a2c-ppo-acktr
"""
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from utils import helpers as utl


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# https://github.com/openai/baselines/blob/master/baselines/common/tf_util.py#L87
def init_normc_(weight, gain=1):
    """In-place "normc" initialisation (OpenAI baselines style).

    Fills ``weight`` with standard-normal samples, then rescales each row
    (the reduction is over dim 1) so that its L2 norm equals ``gain``.
    """
    weight.normal_(0, 1)
    row_norms = weight.pow(2).sum(1, keepdim=True).sqrt()
    weight *= gain / row_norms

def init(module, weight_init, bias_init, gain=1.0):
    """Initialise ``module.weight`` and ``module.bias`` in place and return ``module``.

    ``weight_init`` is called as ``weight_init(weight, gain=gain)``, and
    ``bias_init`` as ``bias_init(bias)``. Both receive the raw ``.data``
    tensors, so autograd does not track the initialisation.
    """
    weight_tensor, bias_tensor = module.weight.data, module.bias.data
    weight_init(weight_tensor, gain=gain)
    bias_init(bias_tensor)
    return module


class NoisyActorCriticRNN(nn.Module):
    """Actor-critic with separate single-layer recurrent actor and critic networks.

    Each branch is: fully connected layers -> RNN/GRU -> fully connected
    layers -> linear output head (action logits for the actor, a single value
    for the critic). While a branch's RNN is in training mode, zero-mean
    Gaussian noise with std ``hidden_noise_std`` is added to that branch's RNN
    outputs; the noisy hidden states are the ones returned by ``forward``.
    """

    # config name -> (activation module class, name understood by nn.init.calculate_gain)
    _ACTIVATIONS = {
        'tanh': (nn.Tanh, 'tanh'),
        'relu': (nn.ReLU, 'relu'),
        'leaky-relu': (nn.LeakyReLU, 'leaky_relu'),
    }
    # config name -> recurrent cell class
    _RNN_CELLS = {'vanilla': nn.RNN, 'gru': nn.GRU}

    def __init__(
        self,
        args,
        # network size
        layers_before_rnn, # list
        rnn_hidden_dim, # int
        layers_after_rnn, # list
        rnn_cell_type, # vanilla, gru
        activation_function,  # tanh, relu, leaky-relu
        initialization_method, # orthogonal, normc
        # states, actions, rewards
        state_dim,
        state_embed_dim,
        action_dim,
        action_embed_dim,
        action_space_type, # Discrete
        reward_dim,
        reward_embed_dim,
        # noise to hidden states
        hidden_noise_std
    ):
        '''
        Separate single-layered noisy RNNs for actor and critic.

        Raises:
            ValueError: unknown activation_function, initialization_method
                or rnn_cell_type.
            NotImplementedError: action_space_type other than 'Discrete'.
        '''
        super().__init__()

        self.args = args

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.reward_dim = reward_dim
        self.state_embed_dim = state_embed_dim
        self.action_embed_dim = action_embed_dim
        self.reward_embed_dim = reward_embed_dim
        self.hidden_noise_std = hidden_noise_std

        # validate the configuration before building anything
        if activation_function not in self._ACTIVATIONS:
            raise ValueError(f'invalid activation_function: {activation_function}')
        weight_inits = {'normc': init_normc_, 'orthogonal': nn.init.orthogonal_}
        if initialization_method not in weight_inits:
            raise ValueError(f'invalid initialization_method: {initialization_method}')
        if rnn_cell_type not in self._RNN_CELLS:
            raise ValueError(f'invalid rnn_cell_type: {rnn_cell_type}')
        if action_space_type != 'Discrete':
            raise NotImplementedError

        # activation applied after every hidden fully connected layer
        activation_cls, gain_name = self._ACTIVATIONS[activation_function]
        self.activation_function = activation_cls()

        # Linear-layer init: weight scheme plus a gain matched to the activation.
        # Plain function references (not lambdas) keep the module picklable.
        self._weight_init_fn = weight_inits[initialization_method]
        self._init_gain = nn.init.calculate_gain(gain_name)

        # embedder for state, action, reward (only when all embed dims are non-zero)
        if self._uses_embedders():
            self.state_encoder = utl.FeatureExtractor(self.state_dim, self.state_embed_dim, F.relu)
            self.action_encoder = utl.FeatureExtractor(self.action_dim, self.action_embed_dim, F.relu)
            self.reward_encoder = utl.FeatureExtractor(self.reward_dim, self.reward_embed_dim, F.relu)
            curr_input_dim = self.state_embed_dim + self.action_embed_dim + self.reward_embed_dim
        else:
            curr_input_dim = self.state_dim + self.action_dim + self.reward_dim

        # fully connected layers before the recurrent cell
        self.actor_fc_before_rnn, actor_fc_before_rnn_final_dim = self.gen_fc_layers(
            layers_before_rnn, curr_input_dim
        )
        self.critic_fc_before_rnn, critic_fc_before_rnn_final_dim = self.gen_fc_layers(
            layers_before_rnn, curr_input_dim
        )

        # recurrent layers (built first, then re-initialised, to keep the RNG order)
        self.rnn_hidden_dim = rnn_hidden_dim
        rnn_cls = self._RNN_CELLS[rnn_cell_type]
        self.actor_rnn = rnn_cls(
            input_size=actor_fc_before_rnn_final_dim,
            hidden_size=self.rnn_hidden_dim,
            num_layers=1)
        self.critic_rnn = rnn_cls(
            input_size=critic_fc_before_rnn_final_dim,
            hidden_size=self.rnn_hidden_dim,
            num_layers=1)
        for rnn in (self.actor_rnn, self.critic_rnn):
            self._init_rnn_params(rnn)

        # fully connected layers after the recurrent cell
        self.actor_fc_after_rnn, actor_fc_after_rnn_final_dim = self.gen_fc_layers(
            layers_after_rnn, self.rnn_hidden_dim
        )
        self.critic_fc_after_rnn, critic_fc_after_rnn_final_dim = self.gen_fc_layers(
            layers_after_rnn, self.rnn_hidden_dim
        )

        # output layers
        self.actor_output = self.init_(nn.Linear(actor_fc_after_rnn_final_dim, action_dim))
        self.policy_dist = torch.distributions.Categorical
        self.critic_output = self.init_(nn.Linear(critic_fc_after_rnn_final_dim, 1))

    def init_(self, module):
        '''Initialise a Linear ``module`` with the configured weight scheme and gain; zero its bias.'''
        return init(module, self._weight_init_fn, lambda x: nn.init.constant_(x, 0), self._init_gain)

    @staticmethod
    def _init_rnn_params(rnn):
        '''Zero the RNN biases and orthogonally initialise its weight matrices.'''
        for name, param in rnn.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0)
            elif 'weight' in name:
                nn.init.orthogonal_(param)

    def _uses_embedders(self):
        '''True when state, action and reward are each passed through a FeatureExtractor.'''
        return self.state_embed_dim != 0 and self.action_embed_dim != 0 and self.reward_embed_dim != 0

    def gen_fc_layers(self, layers, curr_input_dim):
        '''
        Build a stack of initialised Linear layers with output sizes ``layers``.

        Returns:
            (nn.ModuleList of the layers, output dim of the last layer — equal
            to ``curr_input_dim`` when ``layers`` is empty)
        '''
        fc_layers = nn.ModuleList([])
        for out_dim in layers:
            fc_layers.append(self.init_(nn.Linear(curr_input_dim, out_dim)))
            curr_input_dim = out_dim
        return fc_layers, curr_input_dim

    def _apply_fc(self, layers, h):
        '''Apply each Linear layer in ``layers`` followed by the configured activation.'''
        for layer in layers:
            h = self.activation_function(layer(h))
        return h

    def _embed_inputs(self, curr_states, prev_actions, prev_rewards):
        '''Embed (if configured) and concatenate states, actions and rewards on the last dim.'''
        if self._uses_embedders():
            hs = self.state_encoder(curr_states)
            ha = self.action_encoder(prev_actions)
            hr = self.reward_encoder(prev_rewards)
            return torch.cat((hs, ha, hr), dim=-1)
        return torch.cat((curr_states, prev_actions, prev_rewards), dim=-1)

    def _forward_branch(self, fc_before, rnn, fc_after, inputs, prev_hidden_states):
        '''Run one branch (fc -> rnn -> noise in training -> fc); return (features, hidden states).'''
        h = self._apply_fc(fc_before, inputs)
        hidden_states, _ = rnn(h, prev_hidden_states) # rnn output: output, h_n
        if rnn.training:
            # noise is created like hidden_states, so it shares its device and dtype
            hidden_states = hidden_states + self.hidden_noise_std * torch.randn_like(hidden_states)
        h = self._apply_fc(fc_after, hidden_states.clone())
        return h, hidden_states

    def forward_actor(self, inputs, prev_hidden_states):
        '''Actor branch; returns (features for actor_output, possibly-noisy RNN outputs).'''
        return self._forward_branch(
            self.actor_fc_before_rnn, self.actor_rnn, self.actor_fc_after_rnn,
            inputs, prev_hidden_states
        )

    def forward_critic(self, inputs, prev_hidden_states):
        '''Critic branch; returns (features for critic_output, possibly-noisy RNN outputs).'''
        return self._forward_branch(
            self.critic_fc_before_rnn, self.critic_rnn, self.critic_fc_after_rnn,
            inputs, prev_hidden_states
        )

    def init_hidden(self, batch_size, model):
        '''
        Start out with a hidden state of all zeros and compute the prior output.

        The zero state has shape (1, batch_size, rnn_hidden_dim) and is created on
        the model's device and dtype. It is passed through the branch's
        post-RNN layers (with the same activation as in ``forward``) and its
        output head.
        ---
        model: str, "actor" or "critic"

        Returns:
            (prior_output, prior_hidden_states)

        Raises:
            ValueError: ``model`` is neither "actor" nor "critic".
        '''
        # TODO: add option to incorporate the initial state
        if model == 'actor':
            fc_after, output_layer = self.actor_fc_after_rnn, self.actor_output
        elif model == 'critic':
            fc_after, output_layer = self.critic_fc_after_rnn, self.critic_output
        else:
            raise ValueError(f'model can only be actor or critic')

        ref = output_layer.weight
        prior_hidden_states = torch.zeros(
            (1, batch_size, self.rnn_hidden_dim),
            device=ref.device, dtype=ref.dtype, requires_grad=True
        )
        prior_output = output_layer(self._apply_fc(fc_after, prior_hidden_states))
        return prior_output, prior_hidden_states

    def forward(
        self,
        curr_states, prev_actions, prev_rewards,
        actor_prev_hidden_states, critic_prev_hidden_states,
        return_prior=False
    ):
        """
        Actions, states, rewards should be given in shape [sequence_len * batch_size * dim].
        For one-step predictions, sequence_len=1 and hidden_states!=None.
        (hidden_states = [hidden_actor, hidden_critic])
        For feeding in entire trajectories, sequence_len>1 and hidden_state=None.
        In either case, we may return embeddings of length sequence_len+1
        if they include the prior.

        When both previous hidden states are None, each branch starts from the
        all-zero prior returned by ``init_hidden``.

        Returns:
            action_logits, state_values, actor_hidden_states, critic_hidden_states
            (each prefixed with the prior along dim 0 when ``return_prior``).

        Raises:
            ValueError: ``return_prior=True`` while a previous hidden state was
                supplied; no prior exists in that case.
        """
        start_from_prior = actor_prev_hidden_states is None and critic_prev_hidden_states is None
        if return_prior and not start_from_prior:
            raise ValueError(
                'return_prior=True requires actor_prev_hidden_states and '
                'critic_prev_hidden_states to both be None'
            )

        # input shape: sequence_len x batch_size x feature_dim
        h = self._embed_inputs(curr_states, prev_actions, prev_rewards)

        # if hidden_states is none, start with the prior
        if start_from_prior:
            batch_size = curr_states.shape[1]
            prior_action_logits, actor_prior_hidden_states = self.init_hidden(batch_size, 'actor')
            prior_state_values, critic_prior_hidden_states = self.init_hidden(batch_size, 'critic')
            actor_prev_hidden_states = actor_prior_hidden_states.clone()
            critic_prev_hidden_states = critic_prior_hidden_states.clone()

        # forward through actor and critic
        actor_h, actor_hidden_states = self.forward_actor(h, actor_prev_hidden_states)
        critic_h, critic_hidden_states = self.forward_critic(h, critic_prev_hidden_states)

        # outputs
        action_logits = self.actor_output(actor_h)
        state_values = self.critic_output(critic_h)

        if return_prior:
            action_logits = torch.cat((prior_action_logits, action_logits))
            state_values = torch.cat((prior_state_values, state_values))
            actor_hidden_states = torch.cat((actor_prior_hidden_states, actor_hidden_states))
            critic_hidden_states = torch.cat((critic_prior_hidden_states, critic_hidden_states))

        return action_logits, state_values, actor_hidden_states, critic_hidden_states

    def act(
        self,
        curr_states, prev_actions, prev_rewards,
        actor_prev_hidden_states, critic_prev_hidden_states,
        return_prior=False, deterministic=False
    ):
        """
        Returns the (raw) actions and their value.

        Returns:
            actions, action_log_probs, entropy, state_values,
            actor_hidden_states, critic_hidden_states
        """
        action_logits, state_values, actor_hidden_states, critic_hidden_states = self.forward(
            curr_states, prev_actions, prev_rewards,
            actor_prev_hidden_states, critic_prev_hidden_states,
            return_prior=return_prior
        )
        action_pd = self.policy_dist(logits=action_logits)
        if deterministic:
            if isinstance(action_pd, torch.distributions.Categorical):
                # most likely category; works whether or not Categorical.mode exists
                actions = action_pd.probs.argmax(dim=-1)
            else:
                actions = action_pd.mean
        else:
            actions = action_pd.sample()
        action_log_probs = action_pd.log_prob(actions)
        entropy = action_pd.entropy()

        return actions, action_log_probs, entropy, state_values, actor_hidden_states, critic_hidden_states


class ActorCriticRNN(nn.Module):
    """Actor-critic with separate single-layer recurrent actor and critic networks.

    Each branch is: fully connected layers -> RNN/GRU -> fully connected
    layers -> linear output head (action logits for the actor, a single value
    for the critic).
    """

    # config name -> (activation module class, name understood by nn.init.calculate_gain)
    _ACTIVATIONS = {
        'tanh': (nn.Tanh, 'tanh'),
        'relu': (nn.ReLU, 'relu'),
        'leaky-relu': (nn.LeakyReLU, 'leaky_relu'),
    }
    # config name -> recurrent cell class
    _RNN_CELLS = {'vanilla': nn.RNN, 'gru': nn.GRU}

    def __init__(
        self,
        args,
        # network size
        layers_before_rnn, # list
        rnn_hidden_dim, # int
        layers_after_rnn, # list
        rnn_cell_type, # vanilla, gru
        activation_function,  # tanh, relu, leaky-relu
        initialization_method, # orthogonal, normc
        # states, actions, rewards
        state_dim,
        state_embed_dim,
        action_dim,
        action_embed_dim,
        action_space_type, # Discrete
        reward_dim,
        reward_embed_dim
    ):
        '''
        Separate single-layered RNNs for actor and critic.

        Raises:
            ValueError: unknown activation_function, initialization_method
                or rnn_cell_type.
            NotImplementedError: action_space_type other than 'Discrete'.
        '''
        super().__init__()

        self.args = args

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.reward_dim = reward_dim
        self.state_embed_dim = state_embed_dim
        self.action_embed_dim = action_embed_dim
        self.reward_embed_dim = reward_embed_dim

        # validate the configuration before building anything
        if activation_function not in self._ACTIVATIONS:
            raise ValueError(f'invalid activation_function: {activation_function}')
        weight_inits = {'normc': init_normc_, 'orthogonal': nn.init.orthogonal_}
        if initialization_method not in weight_inits:
            raise ValueError(f'invalid initialization_method: {initialization_method}')
        if rnn_cell_type not in self._RNN_CELLS:
            raise ValueError(f'invalid rnn_cell_type: {rnn_cell_type}')
        if action_space_type != 'Discrete':
            raise NotImplementedError

        # activation applied after every hidden fully connected layer
        activation_cls, gain_name = self._ACTIVATIONS[activation_function]
        self.activation_function = activation_cls()

        # Linear-layer init: weight scheme plus a gain matched to the activation.
        # Plain function references (not lambdas) keep the module picklable.
        self._weight_init_fn = weight_inits[initialization_method]
        self._init_gain = nn.init.calculate_gain(gain_name)

        # embedder for state, action, reward (only when all embed dims are non-zero)
        if self._uses_embedders():
            self.state_encoder = utl.FeatureExtractor(self.state_dim, self.state_embed_dim, F.relu)
            self.action_encoder = utl.FeatureExtractor(self.action_dim, self.action_embed_dim, F.relu)
            self.reward_encoder = utl.FeatureExtractor(self.reward_dim, self.reward_embed_dim, F.relu)
            curr_input_dim = self.state_embed_dim + self.action_embed_dim + self.reward_embed_dim
        else:
            curr_input_dim = self.state_dim + self.action_dim + self.reward_dim

        # fully connected layers before the recurrent cell
        self.actor_fc_before_rnn, actor_fc_before_rnn_final_dim = self.gen_fc_layers(
            layers_before_rnn, curr_input_dim
        )
        self.critic_fc_before_rnn, critic_fc_before_rnn_final_dim = self.gen_fc_layers(
            layers_before_rnn, curr_input_dim
        )

        # recurrent layers (built first, then re-initialised, to keep the RNG order)
        self.rnn_hidden_dim = rnn_hidden_dim
        rnn_cls = self._RNN_CELLS[rnn_cell_type]
        self.actor_rnn = rnn_cls(
            input_size=actor_fc_before_rnn_final_dim,
            hidden_size=self.rnn_hidden_dim,
            num_layers=1)
        self.critic_rnn = rnn_cls(
            input_size=critic_fc_before_rnn_final_dim,
            hidden_size=self.rnn_hidden_dim,
            num_layers=1)
        for rnn in (self.actor_rnn, self.critic_rnn):
            self._init_rnn_params(rnn)

        # fully connected layers after the recurrent cell
        self.actor_fc_after_rnn, actor_fc_after_rnn_final_dim = self.gen_fc_layers(
            layers_after_rnn, self.rnn_hidden_dim
        )
        self.critic_fc_after_rnn, critic_fc_after_rnn_final_dim = self.gen_fc_layers(
            layers_after_rnn, self.rnn_hidden_dim
        )

        # output layers
        self.actor_output = self.init_(nn.Linear(actor_fc_after_rnn_final_dim, action_dim))
        self.policy_dist = torch.distributions.Categorical
        self.critic_output = self.init_(nn.Linear(critic_fc_after_rnn_final_dim, 1))

    def init_(self, module):
        '''Initialise a Linear ``module`` with the configured weight scheme and gain; zero its bias.'''
        return init(module, self._weight_init_fn, lambda x: nn.init.constant_(x, 0), self._init_gain)

    @staticmethod
    def _init_rnn_params(rnn):
        '''Zero the RNN biases and orthogonally initialise its weight matrices.'''
        for name, param in rnn.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0)
            elif 'weight' in name:
                nn.init.orthogonal_(param)

    def _uses_embedders(self):
        '''True when state, action and reward are each passed through a FeatureExtractor.'''
        return self.state_embed_dim != 0 and self.action_embed_dim != 0 and self.reward_embed_dim != 0

    def gen_fc_layers(self, layers, curr_input_dim):
        '''
        Build a stack of initialised Linear layers with output sizes ``layers``.

        Returns:
            (nn.ModuleList of the layers, output dim of the last layer — equal
            to ``curr_input_dim`` when ``layers`` is empty)
        '''
        fc_layers = nn.ModuleList([])
        for out_dim in layers:
            fc_layers.append(self.init_(nn.Linear(curr_input_dim, out_dim)))
            curr_input_dim = out_dim
        return fc_layers, curr_input_dim

    def _apply_fc(self, layers, h):
        '''Apply each Linear layer in ``layers`` followed by the configured activation.'''
        for layer in layers:
            h = self.activation_function(layer(h))
        return h

    def _embed_inputs(self, curr_states, prev_actions, prev_rewards):
        '''Embed (if configured) and concatenate states, actions and rewards on the last dim.'''
        if self._uses_embedders():
            hs = self.state_encoder(curr_states)
            ha = self.action_encoder(prev_actions)
            hr = self.reward_encoder(prev_rewards)
            return torch.cat((hs, ha, hr), dim=-1)
        return torch.cat((curr_states, prev_actions, prev_rewards), dim=-1)

    def _forward_branch(self, fc_before, rnn, fc_after, inputs, prev_hidden_states):
        '''Run one branch (fc -> rnn -> fc); return (features, RNN outputs).'''
        h = self._apply_fc(fc_before, inputs)
        hidden_states, _ = rnn(h, prev_hidden_states) # rnn output: output, h_n
        h = self._apply_fc(fc_after, hidden_states.clone())
        return h, hidden_states

    def forward_actor(self, inputs, prev_hidden_states):
        '''Actor branch; returns (features for actor_output, RNN outputs).'''
        return self._forward_branch(
            self.actor_fc_before_rnn, self.actor_rnn, self.actor_fc_after_rnn,
            inputs, prev_hidden_states
        )

    def forward_critic(self, inputs, prev_hidden_states):
        '''Critic branch; returns (features for critic_output, RNN outputs).'''
        return self._forward_branch(
            self.critic_fc_before_rnn, self.critic_rnn, self.critic_fc_after_rnn,
            inputs, prev_hidden_states
        )

    def init_hidden(self, batch_size, model):
        '''
        Start out with a hidden state of all zeros and compute the prior output.

        The zero state has shape (1, batch_size, rnn_hidden_dim) and is created on
        the model's device and dtype. It is passed through the branch's
        post-RNN layers (with the same activation as in ``forward``) and its
        output head.
        ---
        model: str, "actor" or "critic"

        Returns:
            (prior_output, prior_hidden_states)

        Raises:
            ValueError: ``model`` is neither "actor" nor "critic".
        '''
        # TODO: add option to incorporate the initial state
        if model == 'actor':
            fc_after, output_layer = self.actor_fc_after_rnn, self.actor_output
        elif model == 'critic':
            fc_after, output_layer = self.critic_fc_after_rnn, self.critic_output
        else:
            raise ValueError(f'model can only be actor or critic')

        ref = output_layer.weight
        prior_hidden_states = torch.zeros(
            (1, batch_size, self.rnn_hidden_dim),
            device=ref.device, dtype=ref.dtype, requires_grad=True
        )
        prior_output = output_layer(self._apply_fc(fc_after, prior_hidden_states))
        return prior_output, prior_hidden_states

    def forward(
        self,
        curr_states, prev_actions, prev_rewards,
        actor_prev_hidden_states, critic_prev_hidden_states,
        return_prior=False
    ):
        """
        Actions, states, rewards should be given in shape [sequence_len * batch_size * dim].
        For one-step predictions, sequence_len=1 and hidden_states!=None.
        (hidden_states = [hidden_actor, hidden_critic])
        For feeding in entire trajectories, sequence_len>1 and hidden_state=None.
        In either case, we may return embeddings of length sequence_len+1
        if they include the prior.

        When both previous hidden states are None, each branch starts from the
        all-zero prior returned by ``init_hidden``.

        Returns:
            action_logits, state_values, actor_hidden_states, critic_hidden_states
            (each prefixed with the prior along dim 0 when ``return_prior``).

        Raises:
            ValueError: ``return_prior=True`` while a previous hidden state was
                supplied; no prior exists in that case.
        """
        start_from_prior = actor_prev_hidden_states is None and critic_prev_hidden_states is None
        if return_prior and not start_from_prior:
            raise ValueError(
                'return_prior=True requires actor_prev_hidden_states and '
                'critic_prev_hidden_states to both be None'
            )

        # input shape: sequence_len x batch_size x feature_dim
        h = self._embed_inputs(curr_states, prev_actions, prev_rewards)

        # if hidden_states is none, start with the prior
        if start_from_prior:
            batch_size = curr_states.shape[1]
            prior_action_logits, actor_prior_hidden_states = self.init_hidden(batch_size, 'actor')
            prior_state_values, critic_prior_hidden_states = self.init_hidden(batch_size, 'critic')
            actor_prev_hidden_states = actor_prior_hidden_states.clone()
            critic_prev_hidden_states = critic_prior_hidden_states.clone()

        # forward through actor and critic
        actor_h, actor_hidden_states = self.forward_actor(h, actor_prev_hidden_states)
        critic_h, critic_hidden_states = self.forward_critic(h, critic_prev_hidden_states)

        # outputs
        action_logits = self.actor_output(actor_h)
        state_values = self.critic_output(critic_h)

        if return_prior:
            action_logits = torch.cat((prior_action_logits, action_logits))
            state_values = torch.cat((prior_state_values, state_values))
            actor_hidden_states = torch.cat((actor_prior_hidden_states, actor_hidden_states))
            critic_hidden_states = torch.cat((critic_prior_hidden_states, critic_hidden_states))

        return action_logits, state_values, actor_hidden_states, critic_hidden_states

    def act(
        self,
        curr_states, prev_actions, prev_rewards,
        actor_prev_hidden_states, critic_prev_hidden_states,
        return_prior=False, deterministic=False
    ):
        """
        Returns the (raw) actions and their value.

        Returns:
            actions, action_log_probs, entropy, state_values,
            actor_hidden_states, critic_hidden_states
        """
        action_logits, state_values, actor_hidden_states, critic_hidden_states = self.forward(
            curr_states, prev_actions, prev_rewards,
            actor_prev_hidden_states, critic_prev_hidden_states,
            return_prior=return_prior
        )
        action_pd = self.policy_dist(logits=action_logits)
        if deterministic:
            if isinstance(action_pd, torch.distributions.Categorical):
                # most likely category; works whether or not Categorical.mode exists
                actions = action_pd.probs.argmax(dim=-1)
            else:
                actions = action_pd.mean
        else:
            actions = action_pd.sample()
        action_log_probs = action_pd.log_prob(actions)
        entropy = action_pd.entropy()

        return actions, action_log_probs, entropy, state_values, actor_hidden_states, critic_hidden_states


class SharedActorCriticRNN(nn.Module):
    """Actor-critic whose actor and critic share one single-layer recurrent trunk.

    The trunk is: fully connected layers -> RNN/GRU -> fully connected layers.
    Two linear heads sit on top of it: action logits (actor) and a single
    value (critic).
    """

    # config name -> (activation module class, name understood by nn.init.calculate_gain)
    _ACTIVATIONS = {
        'tanh': (nn.Tanh, 'tanh'),
        'relu': (nn.ReLU, 'relu'),
        'leaky-relu': (nn.LeakyReLU, 'leaky_relu'),
    }
    # config name -> recurrent cell class
    _RNN_CELLS = {'vanilla': nn.RNN, 'gru': nn.GRU}

    def __init__(self,
                 args,
                 # network size
                 layers_before_rnn, # list
                 rnn_hidden_dim, # int
                 layers_after_rnn, # list
                 rnn_cell_type, # vanilla, gru
                 activation_function,  # tanh, relu, leaky-relu
                 initialization_method, # orthogonal, normc
                 # states, actions, rewards
                 state_dim,
                 state_embed_dim,
                 action_dim,
                 action_embed_dim,
                 action_space_type, # Discrete
                 reward_dim,
                 reward_embed_dim
                 ):
        '''
        Use a shared single-layered RNN for both actor and critic.

        Raises:
            ValueError: unknown activation_function, initialization_method
                or rnn_cell_type.
            NotImplementedError: action_space_type other than 'Discrete'.
        '''
        super().__init__()

        self.args = args

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.reward_dim = reward_dim
        self.state_embed_dim = state_embed_dim
        self.action_embed_dim = action_embed_dim
        self.reward_embed_dim = reward_embed_dim

        # validate the configuration before building anything
        if activation_function not in self._ACTIVATIONS:
            raise ValueError(f'invalid activation_function: {activation_function}')
        weight_inits = {'normc': init_normc_, 'orthogonal': nn.init.orthogonal_}
        if initialization_method not in weight_inits:
            raise ValueError(f'invalid initialization_method: {initialization_method}')
        if rnn_cell_type not in self._RNN_CELLS:
            raise ValueError(f'invalid rnn_cell_type: {rnn_cell_type}')
        if action_space_type != 'Discrete':
            raise NotImplementedError

        # activation applied after every hidden fully connected layer
        activation_cls, gain_name = self._ACTIVATIONS[activation_function]
        self.activation_function = activation_cls()

        # Linear-layer init: weight scheme plus a gain matched to the activation.
        # Plain function references (not lambdas) keep the module picklable.
        self._weight_init_fn = weight_inits[initialization_method]
        self._init_gain = nn.init.calculate_gain(gain_name)

        # embedder for state, action, reward (only when all embed dims are non-zero)
        if self._uses_embedders():
            self.state_encoder = utl.FeatureExtractor(self.state_dim, self.state_embed_dim, F.relu)
            self.action_encoder = utl.FeatureExtractor(self.action_dim, self.action_embed_dim, F.relu)
            self.reward_encoder = utl.FeatureExtractor(self.reward_dim, self.reward_embed_dim, F.relu)
            curr_input_dim = self.state_embed_dim + self.action_embed_dim + self.reward_embed_dim
        else:
            curr_input_dim = self.state_dim + self.action_dim + self.reward_dim

        # fully connected layers before the recurrent cell
        self.fc_before_rnn, fc_before_rnn_final_dim = self.gen_fc_layers(layers_before_rnn,
                                                                         curr_input_dim)

        # shared recurrent layer: zero biases, orthogonal weights
        self.rnn_hidden_dim = rnn_hidden_dim
        self.shared_rnn = self._RNN_CELLS[rnn_cell_type](
            input_size=fc_before_rnn_final_dim,
            hidden_size=self.rnn_hidden_dim,
            num_layers=1)
        for name, param in self.shared_rnn.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0)
            elif 'weight' in name:
                nn.init.orthogonal_(param)

        # fully connected layers after the recurrent cell
        self.fc_after_rnn, fc_after_rnn_final_dim = self.gen_fc_layers(layers_after_rnn,
                                                                       self.rnn_hidden_dim)

        # output heads
        self.actor_output = self.init_(nn.Linear(fc_after_rnn_final_dim, action_dim))
        self.policy_dist = torch.distributions.Categorical
        self.critic_output = self.init_(nn.Linear(fc_after_rnn_final_dim, 1))

    def init_(self, module):
        '''Initialise a Linear ``module`` with the configured weight scheme and gain; zero its bias.'''
        return init(module, self._weight_init_fn, lambda x: nn.init.constant_(x, 0), self._init_gain)

    def _uses_embedders(self):
        '''True when state, action and reward are each passed through a FeatureExtractor.'''
        return self.state_embed_dim != 0 and self.action_embed_dim != 0 and self.reward_embed_dim != 0

    def gen_fc_layers(self, layers, curr_input_dim):
        '''
        Build a stack of initialised Linear layers with output sizes ``layers``.

        Returns:
            (nn.ModuleList of the layers, output dim of the last layer — equal
            to ``curr_input_dim`` when ``layers`` is empty)
        '''
        fc_layers = nn.ModuleList([])
        for out_dim in layers:
            fc_layers.append(self.init_(nn.Linear(curr_input_dim, out_dim)))
            curr_input_dim = out_dim
        return fc_layers, curr_input_dim

    def _apply_fc(self, layers, h):
        '''Apply each Linear layer in ``layers`` followed by the configured activation.'''
        for layer in layers:
            h = self.activation_function(layer(h))
        return h

    def forward_shared_actor_critic(self, inputs, prev_hidden_states):
        '''Run the shared trunk (fc -> rnn -> fc); return (features for both heads, RNN outputs).'''
        h = self._apply_fc(self.fc_before_rnn, inputs)
        hidden_states, _ = self.shared_rnn(h, prev_hidden_states) # rnn output: output, h_n
        h = self._apply_fc(self.fc_after_rnn, hidden_states.clone())
        return h, hidden_states

    def init_hidden(self, batch_size):
        '''
        Start out with a hidden state of all zeros and compute the prior outputs.

        The zero state has shape (1, batch_size, rnn_hidden_dim) and is created on
        the model's device and dtype. It is passed through the shared post-RNN
        layers (with the same activation as in ``forward``) and both heads.

        Returns:
            (prior_action_logits, prior_state_values, prior_hidden_states)
        '''
        # TODO: add option to incorporate the initial state
        ref = self.actor_output.weight
        prior_hidden_states = torch.zeros(
            (1, batch_size, self.rnn_hidden_dim),
            device=ref.device, dtype=ref.dtype, requires_grad=True
        )
        # the shared trunk's layers (there are no actor-specific ones in this model)
        h = self._apply_fc(self.fc_after_rnn, prior_hidden_states)
        prior_action_logits = self.actor_output(h)
        prior_state_values = self.critic_output(h)
        return prior_action_logits, prior_state_values, prior_hidden_states

    def forward(self,
                curr_states, prev_actions, prev_rewards,
                rnn_prev_hidden_states,
                return_prior=False):
        """
        Actions, states, rewards should be given in shape [sequence_len * batch_size * dim].
        For one-step predictions, sequence_len=1 and hidden_states!=None.
        For feeding in entire trajectories, sequence_len>1 and hidden_state=None.
        In either case, we may return embeddings of length sequence_len+1
        if they include the prior.

        Returns:
            action_logits, state_values, rnn_hidden_states
            (each prefixed with the prior along dim 0 when ``return_prior``).

        Raises:
            ValueError: ``return_prior=True`` while ``rnn_prev_hidden_states``
                was supplied; no prior exists in that case.
        """
        start_from_prior = rnn_prev_hidden_states is None
        if return_prior and not start_from_prior:
            raise ValueError('return_prior=True requires rnn_prev_hidden_states to be None')

        # input shape: sequence_len x batch_size x feature_dim
        if self._uses_embedders():
            hs = self.state_encoder(curr_states)
            ha = self.action_encoder(prev_actions)
            hr = self.reward_encoder(prev_rewards)
            h = torch.cat((hs, ha, hr), dim=-1)
        else:
            h = torch.cat((curr_states, prev_actions, prev_rewards), dim=-1)

        # if hidden_states is none, start with the prior
        if start_from_prior:
            batch_size = curr_states.shape[1]
            prior_action_logits, prior_state_values, rnn_prior_hidden_states = self.init_hidden(batch_size)
            rnn_prev_hidden_states = rnn_prior_hidden_states.clone()

        # forward through shared actor_critic network
        shared_actor_critic_h, rnn_hidden_states = self.forward_shared_actor_critic(h, rnn_prev_hidden_states)

        # outputs
        action_logits = self.actor_output(shared_actor_critic_h)
        state_values = self.critic_output(shared_actor_critic_h)

        if return_prior:
            action_logits = torch.cat((prior_action_logits, action_logits))
            state_values = torch.cat((prior_state_values, state_values))
            rnn_hidden_states = torch.cat((rnn_prior_hidden_states, rnn_hidden_states))

        return action_logits, state_values, rnn_hidden_states

    def act(self,
            curr_states, prev_actions, prev_rewards,
            rnn_prev_hidden_states,
            return_prior=False, deterministic=False):
        """
        Returns the (raw) actions and their value.

        Returns:
            actions, action_log_probs, entropy, state_values, rnn_hidden_states
        """
        action_logits, state_values, rnn_hidden_states = self.forward(
            curr_states, prev_actions, prev_rewards,
            rnn_prev_hidden_states,
            return_prior=return_prior
        )
        action_pd = self.policy_dist(logits=action_logits)
        if deterministic:
            if isinstance(action_pd, torch.distributions.Categorical):
                # most likely category; works whether or not Categorical.mode exists
                actions = action_pd.probs.argmax(dim=-1)
            else:
                actions = action_pd.mean
        else:
            actions = action_pd.sample()
        action_log_probs = action_pd.log_prob(actions)
        entropy = action_pd.entropy()

        return actions, action_log_probs, entropy, state_values, rnn_hidden_states

import torch
from torch import nn, jit
import math


class nmRNNCell_base(jit.ScriptModule):
    """
    Parameter container for a neuromodulated (NM) rate-RNN cell.

    The cell has ``hidden_size`` rate units and ``N_NM`` neuromodulatory units.
    Parameters:
      - weight_ih    (hidden_size, input_size): input weights
      - weight_hh    (hidden_size, hidden_size, N_NM): one recurrent matrix per
                     NM unit (mixed by NM activity in nmRNNCell.forward)
      - weight_h2nm  (N_NM, hidden_size): rate units -> NM units (sparse init)
      - weight_nm2nm (N_NM, N_NM): NM -> NM recurrence (zero init)
      - weight0_hh   (hidden_size, hidden_size): static recurrent matrix;
                     trainable and kaiming-initialised when keepW0 is truthy,
                     otherwise frozen (requires_grad=False) and all zeros
      - bias         (hidden_size,) when ``bias`` is truthy, else None

    ``nonlinearity`` is only stored here; no forward is defined (see nmRNNCell).
    """

    def __init__(self, N_NM, input_size, hidden_size, nonlinearity, bias, keepW0 = False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.nonlinearity = nonlinearity
        self.N_nm = N_NM
        self.keepW0 = keepW0
        # scale used when initialising weight_hh in reset_parameters
        self.g = 10

        self.weight_ih = nn.Parameter(torch.Tensor(hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.Tensor(hidden_size, hidden_size, N_NM))
        self.weight_h2nm = nn.Parameter(torch.Tensor(N_NM, hidden_size))
        self.weight_nm2nm = nn.Parameter(torch.Tensor(N_NM, N_NM))
        # weight0_hh is always registered so the recurrent term can be added
        # unconditionally; it is only trained when keepW0 is truthy.
        self.weight0_hh = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad = bool(keepW0))

        if bias:
            self.bias = nn.Parameter(torch.Tensor(hidden_size))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise all parameters in place (order fixes RNG consumption)."""
        nn.init.kaiming_uniform_(self.weight_ih, a=math.sqrt(5))
        # NOTE(review): `a` is kaiming's leaky-ReLU slope, used here as a scale
        # knob (g / sqrt(hidden_size)); confirm this is the intended init.
        nn.init.kaiming_uniform_(self.weight_hh, a=self.g/math.sqrt(self.hidden_size))
        nn.init.sparse_(self.weight_h2nm, 0.1)
        nn.init.zeros_(self.weight_nm2nm)

        if self.keepW0:
            nn.init.kaiming_uniform_(self.weight0_hh, a=math.sqrt(5))
        else:
            nn.init.zeros_(self.weight0_hh)

        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_ih)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)


class nmRNNCell(nmRNNCell_base):  # Euler integration of rate-neuron network dynamics
    """
    One Euler step of the neuromodulated rate RNN.

    Args:
        N_nm, input_size, hidden_size, bias, keepW0: see nmRNNCell_base.
        nonlinearity: callable applied to the pre-activations; must be set
            (calling the cell with the default None fails).
        decay: leak factor; every unit is updated as
            ``x <- decay * x + (1 - decay) * nonlinearity(pre_activation)``
            (e.g. exp(-dt / tau)).
    """

    def __init__(self, N_nm, input_size, hidden_size, nonlinearity = None, decay = 0, bias = True, keepW0 = True):
        # keepW0 used to be dropped here, so weight0_hh was always a frozen
        # zero matrix regardless of what the caller asked for.
        super().__init__(N_nm, input_size, hidden_size, nonlinearity, bias, keepW0 = keepW0)
        self.decay = decay
        self.N_nm = N_nm

    def forward(self, input, hiddenCombined):
        """
        Advance the combined state by one time step.

        Args:
            input: input for this step; last dim is input_size (nmRNNLayer
                passes one time slice, i.e. (batch, input_size)).
            hiddenCombined: state laid out as [..., hidden_size + N_nm] and
                indexed with three dims (nmRNNLayer passes (1, batch, ...));
                the last N_nm features are the NM units.

        Returns:
            The updated state, concatenated along dim 2 as [hidden, nm]; when
            N_nm == 0 only the updated hidden part is returned.
        """
        # split the combined state into rate units and NM units
        if self.N_nm > 0:
            hidden = hiddenCombined[:, :, :-self.N_nm]
            nm = hiddenCombined[:, :, -self.N_nm:]
        else:
            hidden = hiddenCombined
            nm = None

        drive = input @ self.weight_ih.t()
        if nm is not None:
            # NM-gated recurrence: sum_k nm_k * (W_hh[:, :, k] @ hidden).
            # '...' keeps the leading (1, batch) dims; the former bias branch
            # used 'bj,...' and crashed on the 3-D state.
            drive = drive + torch.einsum('...j,ijk,...k->...i', hidden, self.weight_hh, nm)
        drive = drive + hidden @ self.weight0_hh.t()
        if self.bias is not None:
            drive = drive + self.bias
        activity = self.nonlinearity(drive)

        if nm is None:
            # no NM units: nothing to concatenate (torch.cat with None raised)
            return self.decay * hidden + (1 - self.decay) * activity

        # NM update uses the rate units from *before* this step's update
        activity_nm = self.nonlinearity(hidden @ self.weight_h2nm.t() + nm @ self.weight_nm2nm.t())
        nm = self.decay * nm + (1 - self.decay) * activity_nm
        hidden = self.decay * hidden + (1 - self.decay) * activity
        return torch.cat([hidden, nm], dim = 2)

class nmRNNLayer(nn.Module):
    """
    Unroll an nmRNNCell over the leading (time) dimension of its input.

    Args:
        N_nm, input_size, hidden_size, nonlinearity, decay, bias, keepW0:
            forwarded unchanged to nmRNNCell.
    """

    def __init__(self, N_nm, input_size, hidden_size, nonlinearity, decay = 0.9, bias = False, keepW0 = False):
        super().__init__()
        self.rnncell = nmRNNCell(N_nm, input_size, hidden_size, nonlinearity = nonlinearity, decay = decay, bias = bias, keepW0 = keepW0)
        self.N_nm = N_nm

    def forward(self, input, initH):
        """
        Run the cell once per slice along dim 0 of ``input``.

        Args:
            input: tensor whose dim 0 is time; each ``input[t]`` is fed to the
                cell (the actor-critic passes (T, batch, features)).
            initH: initial combined state, presumably (1, batch,
                hidden_size + N_nm) as built by the actor-critic's init_hidden.

        Returns:
            (outputs, hidden): per-step states with their leading dim squeezed,
            stacked along dim 0, and the final state exactly as the cell
            returned it.
        """
        hidden = initH
        outputs = []
        for x_t in input.unbind(0):
            hidden = self.rnncell(x_t, hidden)
            # the RNN output is the state itself; drop its singleton leading dim
            outputs.append(hidden.squeeze(0))
        return torch.stack(outputs), hidden



class nmActorCriticRNN(nn.Module):
    def __init__(self,
                 args,
                 # network size
                 layers_before_rnn, # list
                 rnn_hidden_dim, # int
                 layers_after_rnn, # list
                 rnn_cell_type, # vanilla, gru, nm
                 activation_function,  # tanh, relu, leaky-relu
                 initialisation_method, # orthogonal, normc
                 # states, actions, rewards
                 state_dim,
                 state_embed_dim,
                 action_dim,
                 action_embed_dim,
                 action_space_type, # Discrete
                 reward_dim,
                 reward_embed_dim,
                 N_nm = 4
                 ):
        '''
        Actor-critic sharing one single-layered neuromodulated RNN.

        The recurrent state holds rnn_hidden_dim rate units followed by N_nm
        neuromodulatory units (see init_hidden). Only the rate units feed the
        fully-connected layers after the RNN and the actor/critic heads.

        rnn_cell_type is accepted for signature compatibility with the other
        actor-critic classes but ignored: the recurrent layer is always an
        nmRNNLayer.

        Raises:
            ValueError: unknown activation_function or initialisation_method.
            NotImplementedError: action_space_type other than 'Discrete'.
        '''
        super(nmActorCriticRNN, self).__init__()

        self.args = args

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.reward_dim = reward_dim
        # (was misspelled `state_fembed_dim`, so the embedder check below
        # raised AttributeError during construction)
        self.state_embed_dim = state_embed_dim
        self.action_embed_dim = action_embed_dim
        self.reward_embed_dim = reward_embed_dim
        self.N_nm = N_nm

        # set activation function
        if activation_function == 'tanh':
            self.activation_function = nn.Tanh()
        elif activation_function == 'relu':
            self.activation_function = nn.ReLU()
        elif activation_function == 'leaky-relu':
            self.activation_function = nn.LeakyReLU()
        else:
            raise ValueError(f'invalid activation_function: {activation_function}')

        # calculate_gain spells it 'leaky_relu'; passing 'leaky-relu' raised.
        gain = nn.init.calculate_gain(activation_function.replace('-', '_'))

        # set initialization method
        if initialisation_method == 'normc':
            weight_init = init_normc_
        elif initialisation_method == 'orthogonal':
            weight_init = nn.init.orthogonal_
        else:
            # previously left self.init_ undefined and failed later
            raise ValueError(f'invalid initialisation_method: {initialisation_method}')
        self.init_ = lambda m: init(m, weight_init,
                                    lambda x: nn.init.constant_(x, 0),
                                    gain)

        # embedder for state, action, reward
        if self._embeds_inputs():
            self.state_encoder = utl.FeatureExtractor(self.state_dim, self.state_embed_dim, F.relu)
            self.action_encoder = utl.FeatureExtractor(self.action_dim, self.action_embed_dim, F.relu)
            self.reward_encoder = utl.FeatureExtractor(self.reward_dim, self.reward_embed_dim, F.relu)
            curr_input_dim = self.state_embed_dim + self.action_embed_dim + self.reward_embed_dim
        else:
            curr_input_dim = self.state_dim + self.action_dim + self.reward_dim

        # fully connected layers before the recurrent cell
        self.fc_before_rnn, fc_before_rnn_final_dim = self.gen_fc_layers(layers_before_rnn,
                                                                         curr_input_dim)

        # recurrent layer
        self.rnn_hidden_dim = rnn_hidden_dim
        self.shared_rnn = nmRNNLayer(
            self.N_nm,
            fc_before_rnn_final_dim,
            self.rnn_hidden_dim,
            self.activation_function)

        # NOTE(review): this re-initialises every cell weight orthogonally,
        # overriding nmRNNCell_base.reset_parameters. That includes the sparse
        # weight_h2nm and, with keepW0=False, the frozen weight0_hh, which
        # becomes a fixed non-zero matrix. Confirm this is intended.
        for name, param in self.shared_rnn.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0)
            elif 'weight' in name:
                nn.init.orthogonal_(param)

        # fully connected layers after the recurrent cell (rate units only)
        curr_input_dim = self.rnn_hidden_dim
        self.fc_after_rnn, fc_after_rnn_final_dim = self.gen_fc_layers(layers_after_rnn,
                                                                       curr_input_dim)

        # output layer
        self.actor_output = self.init_(nn.Linear(fc_after_rnn_final_dim, action_dim))
        if action_space_type == 'Discrete':
            self.policy_dist = torch.distributions.Categorical
        else:
            raise NotImplementedError
        self.critic_output = self.init_(nn.Linear(fc_after_rnn_final_dim, 1))

    def _embeds_inputs(self):
        """True when states, actions and rewards each go through an embedder."""
        return self.state_embed_dim != 0 and self.action_embed_dim != 0 and self.reward_embed_dim != 0

    def gen_fc_layers(self, layers, curr_input_dim):
        """
        Build a chain of initialised Linear layers with widths ``layers``.

        Returns:
            (nn.ModuleList of the layers, output width of the last layer,
            which is ``curr_input_dim`` itself when ``layers`` is empty)
        """
        fc_layers = nn.ModuleList([])
        for layer_dim in layers:
            fc_layers.append(self.init_(nn.Linear(curr_input_dim, layer_dim)))
            curr_input_dim = layer_dim
        return fc_layers, curr_input_dim

    def forward_shared_actor_critic(self, inputs, prev_hidden_states):
        """
        Run fc_before_rnn -> shared nmRNN -> fc_after_rnn.

        Returns:
            (features for the actor/critic heads, full RNN states including
            the N_nm neuromodulatory units, one per time step)
        """
        h = inputs
        for fc in self.fc_before_rnn:
            h = self.activation_function(fc(h))

        hidden_states, _ = self.shared_rnn(h, prev_hidden_states)
        # Keep only the rate units. `[..., :rnn_hidden_dim]` also works for
        # N_nm == 0, where the old `[..., 0:-N_nm]` slice was empty. The clone
        # keeps the features independent of the returned state tensor.
        h = hidden_states[..., :self.rnn_hidden_dim].clone()

        for fc in self.fc_after_rnn:
            h = self.activation_function(fc(h))
        return h, hidden_states

    def init_hidden(self, batch_size):
        '''
        Build the all-zeros prior state and the heads' outputs for it.

        Args:
            batch_size: number of parallel sequences.

        Returns:
            (prior_action_logits, prior_state_values, prior_hidden_states),
            where prior_hidden_states has shape
            (1, batch_size, rnn_hidden_dim + N_nm).
        '''
        # TODO: add option to incorporate the initial state
        prior_hidden_states = torch.zeros((1, batch_size, self.rnn_hidden_dim + self.N_nm),
                                          requires_grad=True).to(device)

        # forward the rate units through the layers after the RNN, using the
        # same activation as forward_shared_actor_critic (was hard-coded F.relu)
        h = prior_hidden_states[..., :self.rnn_hidden_dim]
        for fc in self.fc_after_rnn:
            h = self.activation_function(fc(h))

        prior_action_logits = self.actor_output(h)
        prior_state_values = self.critic_output(h)

        return prior_action_logits, prior_state_values, prior_hidden_states

    def forward(self,
                curr_states, prev_actions, prev_rewards,
                rnn_prev_hidden_states,
                return_prior=False):
        """
        Actions, states, rewards should be given in shape [sequence_len * batch_size * dim].
        For one-step predictions, sequence_len=1 and rnn_prev_hidden_states is the
        previous combined state (same layout as init_hidden's output).
        For feeding in entire trajectories, sequence_len>1 and rnn_prev_hidden_states=None.
        With return_prior=True (only valid when rnn_prev_hidden_states is None) the
        outputs have length sequence_len+1 because the prior step is prepended.

        Returns:
            (action_logits, state_values, rnn_hidden_states)

        Raises:
            ValueError: return_prior=True together with a given hidden state.
        """
        # extract features for states, actions, rewards
        if self._embeds_inputs():
            hs = self.state_encoder(curr_states)
            ha = self.action_encoder(prev_actions)
            hr = self.reward_encoder(prev_rewards)
            h = torch.cat((hs, ha, hr), dim=-1)
        else:
            h = torch.cat((curr_states, prev_actions, prev_rewards), dim=-1)

        # if no hidden state is given, start from the prior
        if rnn_prev_hidden_states is None:
            batch_size = curr_states.shape[1]
            prior_action_logits, prior_state_values, rnn_prior_hidden_states = self.init_hidden(batch_size)
            rnn_prev_hidden_states = rnn_prior_hidden_states
        elif return_prior:
            # previously failed with an unbound-local error
            raise ValueError('return_prior=True requires rnn_prev_hidden_states=None')

        shared_actor_critic_h, rnn_hidden_states = self.forward_shared_actor_critic(h, rnn_prev_hidden_states)

        action_logits = self.actor_output(shared_actor_critic_h)
        state_values = self.critic_output(shared_actor_critic_h)

        if return_prior:
            action_logits = torch.cat((prior_action_logits, action_logits))
            state_values = torch.cat((prior_state_values, state_values))
            rnn_hidden_states = torch.cat((rnn_prior_hidden_states, rnn_hidden_states))

        return action_logits, state_values, rnn_hidden_states

    def act(self, 
            curr_states, prev_actions, prev_rewards, 
            rnn_prev_hidden_states,
            return_prior=False, deterministic=False):
        """
        Run one forward pass and choose actions from the resulting policy.

        With deterministic=True the most likely action is taken (argmax of the
        class probabilities for a Categorical policy, the mean otherwise);
        otherwise an action is sampled.

        Returns:
            (actions, action_log_probs, entropy, state_values, rnn_hidden_states)
        """
        action_logits, state_values, rnn_hidden_states = self.forward(
            curr_states, prev_actions, prev_rewards,
            rnn_prev_hidden_states,
            return_prior=return_prior
        )
        action_pd = self.policy_dist(logits=action_logits)
        if deterministic:
            if isinstance(action_pd, torch.distributions.Categorical):
                # `Distribution.mode` is a property in recent torch (absent in
                # older releases), so `action_pd.mode()` raised.
                actions = action_pd.probs.argmax(dim=-1)
            else:
                actions = action_pd.mean
        else:
            actions = action_pd.sample()
        action_log_probs = action_pd.log_prob(actions)
        entropy = action_pd.entropy()

        return actions, action_log_probs, entropy, state_values, rnn_hidden_states



